{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Theano tensor module: indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using gpu device 1: Tesla C2075 (CNMeM is disabled)\n"
     ]
    }
   ],
   "source": [
    "import theano\n",
    "import theano.tensor as T\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple indexing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `tensor` module fully supports the simple (slice-based) indexing of `numpy`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 3 5 7]\n"
     ]
    }
   ],
   "source": [
    "t = T.arange(9)\n",
    "\n",
    "print t[1::2].eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `numpy` result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 3 5 7]\n"
     ]
    }
   ],
   "source": [
    "n = np.arange(9)\n",
    "\n",
    "print n[1::2]"
   ]
  },
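  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of two more simple-indexing forms, assuming they behave as in `numpy`: a two-dimensional slice and a negative index. The names `t2` and `n2` are illustrative and not part of the original notebook; each Theano result is printed next to its `numpy` counterpart so the two can be compared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import theano.tensor as T\n",
    "\n",
    "# Illustrative 3x3 inputs (hypothetical names, not from the original)\n",
    "t2 = T.arange(9).reshape((3, 3))\n",
    "n2 = np.arange(9).reshape((3, 3))\n",
    "\n",
    "# Rows from index 1 onward, every second column\n",
    "print(t2[1:, ::2].eval())\n",
    "print(n2[1:, ::2])\n",
    "\n",
    "# Negative index: the last row\n",
    "print(t2[-1].eval())\n",
    "print(n2[-1])"
   ]
  },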
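  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mask indexing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although the `tensor` module supports simple indexing, it does not support `mask` (boolean) indexing. As the output shows, the comparison result is treated as an array of integer indices (0s and 1s) that select whole rows, so the following gives a <font color=\"red\">wrong</font> result:"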
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[0 1 2]\n",
      "  [0 1 2]\n",
      "  [0 1 2]]\n",
      "\n",
      " [[0 1 2]\n",
      "  [0 1 2]\n",
      "  [3 4 5]]\n",
      "\n",
      " [[3 4 5]\n",
      "  [3 4 5]\n",
      "  [3 4 5]]]\n"
     ]
    }
   ],
   "source": [
    "t = T.arange(9).reshape((3,3))\n",
    "\n",
    "print t[t > 4].eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result in `numpy`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5 6 7 8]\n"
     ]
    }
   ],
   "source": [
    "n = np.arange(9).reshape((3,3))\n",
    "\n",
    "print n[n > 4]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get the same result as `numpy`, convert the mask to index arrays with `nonzero()` first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5 6 7 8]\n"
     ]
    }
   ],
   "source": [
    "print t[(t > 4).nonzero()].eval()"
   ]
  },
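  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same `nonzero()` idiom should also work on a symbolic input compiled with `theano.function`. The sketch below assumes an `int32` matrix input; the names `x`, `select`, and `data` are illustrative and not part of the original notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import theano\n",
    "import theano.tensor as T\n",
    "\n",
    "# Symbolic int32 matrix (hypothetical input for illustration)\n",
    "x = T.imatrix('x')\n",
    "\n",
    "# Turn the mask into index arrays, then index with them\n",
    "select = theano.function([x], x[(x > 4).nonzero()])\n",
    "\n",
    "data = np.arange(9, dtype='int32').reshape((3, 3))\n",
    "\n",
    "print(select(data))      # Theano, via nonzero()\n",
    "print(data[data > 4])    # numpy boolean mask, for comparison"
   ]
  },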
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assignment via indexing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `tensor` module does not support direct assignment through an index: statements such as `a[5] = b` or `a[5] += b` are not allowed.\n",
    "\n",
    "Instead, `set_subtensor` and `inc_subtensor` provide equivalent functionality:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### T.set_subtensor(x, y)"
   ]
  },
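  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implements the equivalent of `r[10:] = 5`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Symbolic vector; no values are attached yet\n",
    "r = T.vector()\n",
    "\n",
    "# Returns a new symbolic variable; r itself is not modified\n",
    "new_r = T.set_subtensor(r[10:], 5)"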
   ]
  },
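  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the effect on concrete values, the expression can be compiled and called. This is a minimal sketch assuming a 12-element input; `v`, `set_v`, `f`, and `values` are illustrative names. The input array is expected to stay unchanged, because `set_subtensor` builds a new variable instead of modifying its argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import theano\n",
    "import theano.tensor as T\n",
    "\n",
    "v = T.vector('v')                        # symbolic vector of dtype floatX\n",
    "set_v = T.set_subtensor(v[10:], 5)       # symbolic form of v[10:] = 5\n",
    "f = theano.function([v], set_v)\n",
    "\n",
    "# Hypothetical 12-element input, so that v[10:] covers two entries\n",
    "values = np.arange(12, dtype=theano.config.floatX)\n",
    "\n",
    "print(f(values))   # copy with the entries from index 10 onward set to 5\n",
    "print(values)      # the original input"
   ]
  },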
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### T.inc_subtensor(x, y)"
   ]
  },
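  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implements the equivalent of `r[10:] += 5`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Symbolic vector; no values are attached yet\n",
    "r = T.vector()\n",
    "\n",
    "# Returns a new symbolic variable; r itself is not modified\n",
    "new_r = T.inc_subtensor(r[10:], 5)"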
   ]
  }
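  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As with `set_subtensor`, the increment can be checked on concrete values by compiling the expression. This is a minimal sketch assuming a 12-element input; `w`, `inc_w`, `g`, and `values` are illustrative names and not part of the original notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import theano\n",
    "import theano.tensor as T\n",
    "\n",
    "w = T.vector('w')                        # symbolic vector of dtype floatX\n",
    "inc_w = T.inc_subtensor(w[10:], 5)       # symbolic form of w[10:] += 5\n",
    "g = theano.function([w], inc_w)\n",
    "\n",
    "# Hypothetical 12-element input, so that w[10:] covers two entries\n",
    "values = np.arange(12, dtype=theano.config.floatX)\n",
    "\n",
    "print(g(values))   # copy with 5 added to the entries from index 10 onward\n",
    "print(values)      # the original input"
   ]
  }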
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
